{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 附一 基于 LangChain 自定义 LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LangChain 为基于 LLM 开发自定义应用提供了高效的开发框架，便于开发者迅速地激发 LLM 的强大能力，搭建 LLM 应用。LangChain 也同样支持多种大模型，内置了 OpenAI、LLAMA 等大模型的调用接口。但是，LangChain 并没有内置所有大模型，它通过允许用户自定义 LLM 类型，来提供强大的可扩展性。\n",
    "\n",
    "在本部分，我们以百度文心大模型为例，讲述如何基于 LangChain 自定义 LLM，让我们基于 LangChain 搭建的应用能够支持百度文心、讯飞星火等国内大模型。\n",
    "\n",
    "本部分涉及相对更多 LangChain、大模型调用的技术细节，有精力同学可以学习部署，如无精力可以直接使用后续代码来支持调用。\n",
    "\n",
    "要实现自定义 LLM，需要定义一个自定义类继承自 LangChain 的 LLM 基类，然后定义两个函数：  \n",
    "① _call 方法，其接受一个字符串，并返回一个字符串，即模型的核心调用；  \n",
    "② _identifying_params 方法，用于打印 LLM 信息。  \n",
    "\n",
    "首先我们导入所需的第三方库："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import time\n",
    "from typing import Any, List, Mapping, Optional, Dict, Union, Tuple\n",
    "import requests\n",
    "from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
    "from langchain.llms.base import LLM\n",
    "from langchain.utils import get_from_dict_or_env\n",
    "from pydantic import Field, model_validator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于百度文心使用双重秘钥进行认证，用户需要先基于 API_Key 与 Secret_Key 来获取 access_token，再使用 access_token 来实现对模型的调用（详见《3. 调用百度文心》），因此我们需要先定义一个 get_access_token 方法来获取 access_token："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_access_token(api_key : str, secret_key : str):\n",
    "    \"\"\"\n",
    "    使用 API Key，Secret Key 获取access_token，替换下列示例中的应用API Key、应用Secret Key\n",
    "    \"\"\"\n",
    "    # 指定网址\n",
    "    url = f\"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}\"\n",
    "    # 设置 POST 访问\n",
    "    payload = json.dumps(\"\")\n",
    "    headers = {\n",
    "        'Content-Type': 'application/json',\n",
    "        'Accept': 'application/json'\n",
    "    }\n",
    "    # 通过 POST 访问获取账户对应的 access_token\n",
    "    response = requests.request(\"POST\", url, headers=headers, data=payload)\n",
    "    return response.json().get(\"access_token\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接着我们定义一个继承自 LLM 类的自定义 LLM 类："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 继承自 langchain.llms.base.LLM\n",
    "class Wenxin_LLM(LLM):\n",
    "    # 原生接口地址\n",
    "    url = \"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant\"\n",
    "    # 默认选用 ERNIE-Bot-turbo 模型，即目前一般所说的百度文心大模型\n",
    "    model_name: str = Field(default=\"ERNIE-Bot-turbo\", alias=\"model\")\n",
    "    # 访问时延上限\n",
    "    request_timeout: Optional[Union[float, Tuple[float, float]]] = None\n",
    "    # 温度系数\n",
    "    temperature: float = 0.1\n",
    "    # API_Key\n",
    "    api_key: str = None\n",
    "    # Secret_Key\n",
    "    secret_key : str = None\n",
    "    # access_token\n",
    "    access_token: str = None\n",
    "    # 必备的可选参数\n",
    "    model_kwargs: Dict[str, Any] = Field(default_factory=dict)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上述初始化涵盖了我们平时常用的参数，也可以根据实际需求与文心的 API 加入更多的参数。\n",
    "\n",
    "接下来我们实现一个初始化方法 init_access_token，当模型的 access_token 为空时调用："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_access_token(self):\n",
    "    if self.api_key != None and self.secret_key != None:\n",
    "        # 两个 Key 均非空才可以获取 access_token\n",
    "        try:\n",
    "            self.access_token = get_access_token(self.api_key, self.secret_key)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            print(\"获取 access_token 失败，请检查 Key\")\n",
    "    else:\n",
    "        print(\"API_Key 或 Secret_Key 为空，请检查 Key\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们实现核心的方法——调用模型 API："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _call(self, prompt : str, stop: Optional[List[str]] = None,\n",
    "                run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "                **kwargs: Any):\n",
    "    # 除 prompt 参数外，其他参数并没有被用到，但当我们通过 LangChain 调用时会传入这些参数，因此必须设置\n",
    "    # 如果 access_token 为空，初始化 access_token\n",
    "    if self.access_token == None:\n",
    "        self.init_access_token()\n",
    "    # API 调用 url\n",
    "    url = \"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={}\".format(self.access_token)\n",
    "    # 配置 POST 参数\n",
    "    payload = json.dumps({\n",
    "        \"messages\": [\n",
    "            {\n",
    "                \"role\": \"user\",# user prompt\n",
    "                \"content\": \"{}\".format(prompt)# 输入的 prompt\n",
    "            }\n",
    "        ],\n",
    "        'temperature' : self.temperature\n",
    "    })\n",
    "    headers = {\n",
    "        'Content-Type': 'application/json'\n",
    "    }\n",
    "    # 发起请求\n",
    "    response = requests.request(\"POST\", url, headers=headers, data=payload, timeout=self.request_timeout)\n",
    "    if response.status_code == 200:\n",
    "        # 返回的是一个 Json 字符串\n",
    "        js = json.loads(response.text)\n",
    "        return js[\"result\"]\n",
    "    else:\n",
    "        return \"请求失败\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们还需要定义一下模型的描述方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 首先定义一个返回默认参数的方法\n",
    "@property\n",
    "def _default_params(self) -> Dict[str, Any]:\n",
    "    \"\"\"获取调用Ennie API的默认参数。\"\"\"\n",
    "    normal_params = {\n",
    "        \"temperature\": self.temperature,\n",
    "        \"request_timeout\": self.request_timeout,\n",
    "        }\n",
    "    return {**normal_params}\n",
    "\n",
    "\n",
    "@property\n",
    "def _identifying_params(self) -> Mapping[str, Any]:\n",
    "    \"\"\"Get the identifying parameters.\"\"\"\n",
    "    return {**{\"model_name\": self.model_name}, **self._default_params}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过上述步骤，我们就可以基于 LangChain 定义百度文心的调用方式了。我们将此代码封装在 wenxin_llm.py 文件中，将在讲述如何调用百度文心的 Notebook 中直接使用该 LLM。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
